from utils.inference_dataset import InferenceDataset
import sys
sys.path.append(".")
sys.path.append("..")
from models.psp import pSp
from utils.test_options import TestOptions
from utils.common import tensor2im, log_input_image
from configs import data_configs
import os
from argparse import Namespace

from tqdm import tqdm
import time
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader

def run():
	"""Run inference over a dataset and write images, latent codes and timing stats.

	Options come from ``TestOptions().parse()`` merged on top of the options
	stored in the checkpoint under ``'opts'``. Outputs:

	* ``<exp_dir>/inference_results[/downsampling_<f>]/`` - one output image per
	  input (same basename), plus ``inverted_codes.pt`` and
	  ``inverted_codes.npy`` holding the latent codes of every processed image.
	* ``<exp_dir>/inference_coupled[/downsampling_<f>]/`` - side-by-side
	  comparison images (every image if ``couple_outputs`` is set, otherwise
	  every 100th).
	* ``<exp_dir>/stats.txt`` - mean and standard deviation of the per-batch
	  forward-pass time.

	The ``n_images`` limit is checked once per batch, so the run stops at the
	first batch boundary at or past the limit.
	"""
	test_opts = TestOptions().parse()

	if test_opts.resize_factors is not None:
		assert len(test_opts.resize_factors.split(
			',')) == 1, "When running inference, provide a single downsampling factor!"
		factor_dir = 'downsampling_{}'.format(test_opts.resize_factors)
		out_path_results = os.path.join(test_opts.exp_dir, 'inference_results', factor_dir)
		out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled', factor_dir)
	else:
		out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
		out_path_coupled = os.path.join(test_opts.exp_dir, 'inference_coupled')

	os.makedirs(out_path_results, exist_ok=True)
	os.makedirs(out_path_coupled, exist_ok=True)

	# update test options with options used during training
	ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
	opts = ckpt['opts']
	opts.update(vars(test_opts))
	if 'learn_in_w' not in opts:
		opts['learn_in_w'] = False
	opts = Namespace(**opts)

	net = pSp(opts)
	net.eval()
	net.cuda()

	print('Loading dataset for {}'.format(opts.dataset_type))
	dataset_args = data_configs.DATASETS[opts.dataset_type]
	transforms_dict = dataset_args['transforms'](opts).get_transforms()
	dataset = InferenceDataset(root=opts.data_path,
	                           transform=transforms_dict['transform_inference'],
	                           opts=opts)
	# Keep the trailing partial batch: dropping it silently skips the last
	# images whenever len(dataset) is not a multiple of the batch size.
	dataloader = DataLoader(dataset,
	                        batch_size=opts.test_batch_size,
	                        shuffle=False,
	                        num_workers=int(opts.test_workers),
	                        drop_last=False)

	if opts.n_images is None:
		opts.n_images = len(dataset)

	global_i = 0
	global_time = []
	# Latent codes of every processed batch; written to disk once after the
	# loop so that earlier batches are not overwritten by later ones.
	all_codes = []
	for input_batch in tqdm(dataloader):
		if global_i >= opts.n_images:
			break
		with torch.no_grad():
			input_cuda = input_batch.cuda().float()
			tic = time.time()
			result_batch, code_batch = run_on_batch(input_cuda, net, opts)
			# CUDA kernels are launched asynchronously; wait for them so the
			# measured interval covers the forward pass and nothing else.
			torch.cuda.synchronize()
			toc = time.time()
		global_time.append(toc - tic)
		all_codes.append(code_batch.detach())

		# Iterate over what the network actually returned so a short final
		# batch does not index past the end.
		for i, result_tensor in enumerate(result_batch):
			result = tensor2im(result_tensor)
			im_path = dataset.paths[global_i]

			if opts.couple_outputs or global_i % 100 == 0:
				input_im = log_input_image(input_batch[i], opts)
				resize_amount = (256, 256) if opts.resize_outputs else (1024, 1024)
				if opts.resize_factors is not None:
					# for super resolution, save the original, down-sampled, and output
					source = Image.open(im_path)
					res = np.concatenate([np.array(source.resize(resize_amount)),
					                      np.array(input_im.resize(
					                          resize_amount, resample=Image.NEAREST)),
					                      np.array(result.resize(resize_amount))], axis=1)
				else:
					# otherwise, save the original and output
					res = np.concatenate([np.array(input_im.resize(resize_amount)),
					                      np.array(result.resize(resize_amount))], axis=1)
				Image.fromarray(res).save(os.path.join(
					out_path_coupled, os.path.basename(im_path)))

			im_save_path = os.path.join(out_path_results, os.path.basename(im_path))
			Image.fromarray(np.array(result)).save(im_save_path)

			global_i += 1

	if all_codes:
		codes = torch.cat(all_codes, dim=0)
		torch.save(codes, f'{out_path_results}/inverted_codes.pt')
		# np.concatenate over the leading (image) axis joins the per-image code
		# arrays end to end - the same layout the per-batch save produced.
		np.save(f'{out_path_results}/inverted_codes.npy',
		        np.concatenate(codes.cpu().numpy(), axis=0))

	stats_path = os.path.join(opts.exp_dir, 'stats.txt')
	result_str = 'Runtime {:.4f}+-{:.4f}'.format(
		np.mean(global_time), np.std(global_time))
	print(result_str)

	with open(stats_path, 'w') as f:
		f.write(result_str)


def run_on_batch(inputs, net, opts):
	"""Run ``net`` on a batch and return ``(result_batch, code_batch)``.

	When ``opts.latent_mask`` is None the whole batch goes through ``net`` in a
	single call. Otherwise ``opts.latent_mask`` is parsed as a comma-separated
	list of ints and each input is processed on its own: a random ``(1, 512)``
	float32 vector is passed through ``net`` with ``input_code=True`` to obtain
	a latent, which is then injected (with ``alpha=opts.mix_alpha``) while
	running ``net`` on that input. The per-input results and codes are
	concatenated along dim 0.
	"""
	if opts.latent_mask is None:
		images, codes = net(inputs,
		                    randomize_noise=False,
		                    resize=opts.resize_outputs,
		                    return_latents=True)
		return images, codes

	mask_indices = list(map(int, opts.latent_mask.split(",")))
	image_parts, code_parts = [], []
	for sample in inputs:
		# get latent vector to inject into our input image
		random_vec = torch.from_numpy(np.random.randn(1, 512).astype('float32'))
		_, injected_latent = net(random_vec.to("cuda"),
		                         input_code=True,
		                         return_latents=True)
		# get output image with injected style vector
		single_input = sample.unsqueeze(0).to("cuda").float()
		mixed_image, mixed_code = net(single_input,
		                              latent_mask=mask_indices,
		                              inject_latent=injected_latent,
		                              alpha=opts.mix_alpha,
		                              resize=opts.resize_outputs,
		                              return_latents=True)
		image_parts.append(mixed_image)
		code_parts.append(mixed_code)

	return torch.cat(image_parts, dim=0), torch.cat(code_parts, dim=0)

# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
	run()
